# -*- coding: utf-8 -*-
"""
Created on Thu Apr 16 16:11:25 2020

@author: xiang_yaobing
"""

import os 
from PIL import Image
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM,Dense,Activation,SimpleRNN,Conv2D,MaxPool2D,Flatten,Reshape,Dropout
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.metrics import categorical_accuracy
from tensorflow.keras.optimizers import RMSprop
import numpy as np
from tensorflow.keras import utils
from sklearn.model_selection import train_test_split

def load_data(path1, path2, path3, path4):
    """Load paired images from four directories into a 2-channel dataset.

    path1/path2 hold the two channels of class 0 (no open cluster);
    path3/path4 hold the two channels of class 1 (open cluster).
    Files are paired by sorted filename order, so each directory pair must
    contain the same number of files with matching names/order — TODO confirm
    this holds for the actual data layout.

    Returns:
        x: np.ndarray of shape (n1 + n2, H, W, 2), the stacked image pairs.
        y: np.ndarray of shape (n1 + n2,), int labels (0 then 1).
    """
    # Sort the listings: os.listdir order is arbitrary, and the original
    # unsorted lists could pair channel images from different objects.
    filelist1 = sorted(os.path.join(path1, f) for f in os.listdir(path1))
    filelist2 = sorted(os.path.join(path2, f) for f in os.listdir(path2))
    filelist3 = sorted(os.path.join(path3, f) for f in os.listdir(path3))
    filelist4 = sorted(os.path.join(path4, f) for f in os.listdir(path4))

    def _pair(fa, fb):
        # Stack the two images along a trailing channel axis -> (H, W, 2).
        # NOTE: the original code reshaped the (2, H, W) array to (H, W, 2),
        # which only reinterprets the memory layout and scrambles the two
        # channels together; np.stack places each image in its own channel.
        ima = np.array(Image.open(fa))
        imb = np.array(Image.open(fb))
        return np.stack([ima, imb], axis=-1)

    samples = [_pair(fa, fb) for fa, fb in zip(filelist1, filelist2)]
    samples += [_pair(fa, fb) for fa, fb in zip(filelist3, filelist4)]

    n1 = len(filelist1)
    n2 = len(filelist3)
    print(n2)

    x = np.array(samples)
    # Labels: first n1 samples are class 0, remaining n2 are class 1.
    y = np.concatenate([np.zeros(n1, dtype=int), np.ones(n2, dtype=int)])
    return x, y


#%%
# Small two-block CNN for 60x60 2-channel KDE images, binary softmax output.
model = Sequential([
    # Conv block 1: 6 filters, 5x5 kernel, same-padding, relu.
    Conv2D(6, (5, 5), padding='Same', activation='relu',
           input_shape=(60, 60, 2)),
    # 2x2 max pooling halves the spatial resolution.
    MaxPool2D(pool_size=(2, 2)),
    # Conv block 2: 16 filters, 3x3 kernel.
    Conv2D(16, (3, 3), padding='Same', activation='relu'),
    MaxPool2D(pool_size=(2, 2), strides=(2, 2)),
    # Classifier head: flatten, two dense layers with 50% dropout each
    # to limit overfitting.
    Flatten(),
    Dense(64, activation='relu'),
    Dropout(0.5),
    Dense(32, activation='relu'),
    Dropout(0.5),
    # Output layer: 2-way softmax (cluster / no cluster).
    Dense(2, activation='softmax'),
])

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
print(model.summary())
if __name__ == "__main__":

    nb_classes = 2
    # Directory layout: path1/path2 are the two channels of the negative
    # class, path3/path4 of the positive class (see load_data).
    path1 = 'E:\\疏散星团论文修改处理\\trainset\\NoClusterKDEpicture\\csv_dataset_no_cluster\\loc0.1\\'
    path2 = 'E:\\疏散星团论文修改处理\\trainset\\NoClusterKDEpicture\\csv_dataset_no_cluster\\loc\\'
    path3 = 'E:\\疏散星团论文修改处理\\trainset\\ClusterKDEpicture\\csv_dataset\\loc0.1\\'
    path4 = 'E:\\疏散星团论文修改处理\\trainset\\ClusterKDEpicture\\csv_dataset\\loc\\'
    data, labels = load_data(path1, path2, path3, path4)

    # Hold out 1% as a test split; validation comes from validation_split
    # during fit.
    x_train, x_test, y_train, y_test = train_test_split(data, labels, test_size=0.01)
    # Pass nb_classes explicitly: the 1% test split can easily contain only
    # one class, and to_categorical would then infer a one-column matrix
    # inconsistent with the model's 2-unit softmax output.
    y_train = utils.to_categorical(y_train, nb_classes)
    y_test = utils.to_categorical(y_test, nb_classes)
    # train_test_split already returns ndarrays for ndarray input, so no
    # extra np.array conversion is needed.
    model.fit(x_train, y_train, validation_split=0.1, verbose=1, batch_size=50, epochs=25)
    model.save('cnn_model_loc_v1.h5')